{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7yuytuIllsv1"
      },
      "source": [
        "# Trax Quick Intro\n",
        "\n",
        "[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google.com/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information.\n",
        "\n",
        "  1. **Run a pre-trained Transformer**: create a translator in a few lines of code\n",
        "  1. **Features and resources**: [API docs](https://trax-ml.readthedocs.io/en/latest/trax.html), where to [talk to us](https://gitter.im/trax-ml/community), how to [open an issue](https://github.com/google/trax/issues) and more\n",
        "  1. **Walkthrough**: how Trax works, how to make new models and train on your own data\n",
        "\n",
        "We welcome **contributions** to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love **notebooks** that explain how models work and show how to use them to solve problems!\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BIl27504La0G"
      },
      "source": [
        "**General Setup**\n",
        "\n",
        "Execute the following few cells (once) before running any of the code samples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 7779,
          "status": "ok",
          "timestamp": 1596062882791,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "oILRLCWN_16u"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "# Copyright 2020 Google LLC.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "cellView": "both",
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 360,
          "status": "ok",
          "timestamp": 1596062883175,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "vlGjGoGMTt-D",
        "outputId": "5afa608b-51a1-47ba-8bec-a20632e44cff"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "/bin/sh: pip: command not found\n"
          ]
        }
      ],
      "source": [
        "#@title\n",
        "# Import Trax\n",
        "\n",
        "!pip install -q -U trax\n",
        "import trax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-LQ89rFFsEdk"
      },
      "source": [
        "## 1. Run a pre-trained Transformer\n",
        "\n",
        "Here is how you create an Engligh-German translator in a few lines of code:\n",
        "\n",
        "* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)\n",
        "* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)\n",
        "* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)\n",
        "* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)\n",
        "* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 36223,
          "status": "ok",
          "timestamp": 1596062919406,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "djTiSLcaNFGa",
        "outputId": "9ac13b4f-f830-4edf-c61b-d1cc877248ff"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Es ist schön, heute neue Dinge zu lernen!\n"
          ]
        }
      ],
      "source": [
        "\n",
        "# Create a Transformer model.\n",
        "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n",
        "model = trax.models.Transformer(\n",
        "    input_vocab_size=33300,\n",
        "    d_model=512, d_ff=2048,\n",
        "    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n",
        "    max_len=2048, mode='predict')\n",
        "\n",
        "# Initialize using pre-trained weights.\n",
        "model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n",
        "                     weights_only=True)\n",
        "\n",
        "# Tokenize a sentence.\n",
        "sentence = 'It is nice to learn new things today!'\n",
        "tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.\n",
        "                                    vocab_dir='gs://trax-ml/vocabs/',\n",
        "                                    vocab_file='ende_32k.subword'))[0]\n",
        "\n",
        "# Decode from the Transformer.\n",
        "tokenized = tokenized[None, :]  # Add batch dimension.\n",
        "tokenized_translation = trax.supervised.decoding.autoregressive_sample(\n",
        "    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.\n",
        "\n",
        "# De-tokenize,\n",
        "tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.\n",
        "translation = trax.data.detokenize(tokenized_translation,\n",
        "                                   vocab_dir='gs://trax-ml/vocabs/',\n",
        "                                   vocab_file='ende_32k.subword')\n",
        "print(translation)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QMo3OnsGgLNK"
      },
      "source": [
        "## 2. Features and resources\n",
        "\n",
        "Trax includes basic models (like [ResNet](https://github.com/google/trax/blob/master/trax/models/resnet.py#L70), [LSTM](https://github.com/google/trax/blob/master/trax/models/rnn.py#L100), [Transformer](https://github.com/google/trax/blob/master/trax/models/transformer.py#L189) and RL algorithms\n",
        "(like [REINFORCE](https://github.com/google/trax/blob/master/trax/rl/training.py#L244), [A2C](https://github.com/google/trax/blob/master/trax/rl/actor_critic_joint.py#L458), [PPO](https://github.com/google/trax/blob/master/trax/rl/actor_critic_joint.py#L209)). It is also actively used for research and includes\n",
        "new models like the [Reformer](https://github.com/google/trax/tree/master/trax/models/reformer) and new RL algorithms like [AWR](https://arxiv.org/abs/1910.00177). Trax has bindings to a large number of deep learning datasets, including\n",
        "[Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) and [TensorFlow datasets](https://www.tensorflow.org/datasets/catalog/overview).\n",
        "\n",
        "\n",
        "You can use Trax either as a library from your own python scripts and notebooks\n",
        "or as a binary from the shell, which can be more convenient for training large models.\n",
        "It runs without any changes on CPUs, GPUs and TPUs.\n",
        "\n",
        "* [API docs](https://trax-ml.readthedocs.io/en/latest/)\n",
        "* [chat with us](https://gitter.im/trax-ml/community)\n",
        "* [open an issue](https://github.com/google/trax/issues)\n",
        "* subscribe to [trax-discuss](https://groups.google.com/u/1/g/trax-discuss) for news\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8wgfJyhdihfR"
      },
      "source": [
        "## 3. Walkthrough\n",
        "\n",
        "You can learn here how Trax works, how to create new models and how to train them on your own data."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yM12hgQnp4qo"
      },
      "source": [
        "### Tensors and Fast Math\n",
        "\n",
        "The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
        "\n",
        "In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "height": 136
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 341,
          "status": "ok",
          "timestamp": 1596062919772,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "kSauPt0NUl_o",
        "outputId": "0ac1e23e-10fe-434b-decc-dca274275fc4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "matrix = \n",
            "[[1 2 3]\n",
            " [4 5 6]\n",
            " [7 8 9]]\n",
            "vector = [1. 1. 1.]\n",
            "product = [12. 15. 18.]\n",
            "tanh(product) = [1. 1. 1.]\n"
          ]
        }
      ],
      "source": [
        "from trax.fastmath import numpy as fastnp\n",
        "trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.\n",
        "\n",
        "matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n",
        "print(f'matrix =\\n{matrix}')\n",
        "vector = fastnp.ones(3)\n",
        "print(f'vector = {vector}')\n",
        "product = fastnp.dot(vector, matrix)\n",
        "print(f'product = {product}')\n",
        "tanh = fastnp.tanh(product)\n",
        "print(f'tanh(product) = {tanh}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "snLYtU6OsKU2"
      },
      "source": [
        "Gradients can be calculated using `trax.fastmath.grad`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "height": 51
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 356,
          "status": "ok",
          "timestamp": 1596062920135,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "cqjYoxPEu8PG",
        "outputId": "80cf3a85-3246-4d54-bea5-60336be9584f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "grad(2x^2) at 1 = 4.0\n",
            "grad(2x^2) at -2 = -8.0\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "  return 2.0 * x * x\n",
        "\n",
        "grad_f = trax.fastmath.grad(f)\n",
        "\n",
        "print(f'grad(2x^2) at 1 = {grad_f(1.0)}')\n",
        "print(f'grad(2x^2) at -2 = {grad_f(-2.0)}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "p-wtgiWNseWw"
      },
      "source": [
        "### Layers\n",
        "\n",
        "Layers are basic building blocks of Trax models. You will learn all about them in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) but for now, just take a look at the implementation of one core Trax layer, `Embedding`:\n",
        "\n",
        "```\n",
        "class Embedding(base.Layer):\n",
        "  \"\"\"Trainable layer that maps discrete tokens/ids to vectors.\"\"\"\n",
        "\n",
        "  def __init__(self,\n",
        "               vocab_size,\n",
        "               d_feature,\n",
        "               kernel_initializer=init.RandomNormalInitializer(1.0)):\n",
        "    \"\"\"Returns an embedding layer with given vocabulary size and vector size.\n",
        "\n",
        "    Args:\n",
        "      vocab_size: Size of the input vocabulary. The layer will assign a unique\n",
        "          vector to each id in `range(vocab_size)`.\n",
        "      d_feature: Dimensionality/depth of the output vectors.\n",
        "      kernel_initializer: Function that creates (random) initial vectors for\n",
        "          the embedding.\n",
        "    \"\"\"\n",
        "    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')\n",
        "    self._d_feature = d_feature  # feature dimensionality\n",
        "    self._vocab_size = vocab_size\n",
        "    self._kernel_initializer = kernel_initializer\n",
        "\n",
        "  def forward(self, x):\n",
        "    \"\"\"Returns embedding vectors corresponding to input token id's.\n",
        "\n",
        "    Args:\n",
        "      x: Tensor of token id's.\n",
        "\n",
        "    Returns:\n",
        "      Tensor of embedding vectors.\n",
        "    \"\"\"\n",
        "    return jnp.take(self.weights, x, axis=0)\n",
        "\n",
        "  def init_weights_and_state(self, input_signature):\n",
        "    \"\"\"Randomly initializes this layer's weights.\"\"\"\n",
        "    del input_signature\n",
        "    shape_w = (self._vocab_size, self._d_feature)\n",
        "    w = self._kernel_initializer(shape_w, self.rng)\n",
        "    self.weights = w\n",
        "```\n",
        "\n",
        "Layers with trainable weights like `Embedding` need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "height": 51
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 445,
          "status": "ok",
          "timestamp": 1596062920592,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "4MLSQsIiw9Aw",
        "outputId": "baca036b-09ec-43bf-f11f-104674c72306"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]\n",
            "shape of y = (15, 32)\n"
          ]
        }
      ],
      "source": [
        "from trax import layers as tl\n",
        "\n",
        "# Create an input tensor x.\n",
        "x = np.arange(15)\n",
        "print(f'x = {x}')\n",
        "\n",
        "# Create the embedding layer.\n",
        "embedding = tl.Embedding(vocab_size=20, d_feature=32)\n",
        "embedding.init(trax.shapes.signature(x))\n",
        "\n",
        "# Run the layer -- y = embedding(x).\n",
        "y = embedding(x)\n",
        "print(f'shape of y = {y.shape}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MgCPl9ZOyCJw"
      },
      "source": [
        "### Models\n",
        "\n",
        "Models in Trax are built from layers most often using the `Serial` and `Branch` combinators. You can read more about those combinators in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) and\n",
        "see the code for many models in `trax/models/`, e.g., this is how the [Transformer Language Model](https://github.com/google/trax/blob/master/trax/models/transformer.py#L167) is implemented. Below is an example of how to build a sentiment classification model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "height": 119
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 329,
          "status": "ok",
          "timestamp": 1596062920941,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "WoSz5plIyXOU",
        "outputId": "e037e452-dddc-4d53-8b04-229f64322e3f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Serial[\n",
            "  Embedding_8192_256\n",
            "  Mean\n",
            "  Dense_2\n",
            "  LogSoftmax\n",
            "]\n"
          ]
        }
      ],
      "source": [
        "model = tl.Serial(\n",
        "    tl.Embedding(vocab_size=8192, d_feature=256),\n",
        "    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).\n",
        "    tl.Dense(2),      # Classify 2 classes.\n",
        "    tl.LogSoftmax()   # Produce log-probabilities.\n",
        ")\n",
        "\n",
        "# You can print model structure.\n",
        "print(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FcnIjFLD0Ju1"
      },
      "source": [
        "### Data\n",
        "\n",
        "To train your model, you need data. In Trax, data streams are represented as python iterators, so you can call `next(data_stream)` and get a tuple, e.g., `(inputs, targets)`. Trax allows you to use [TensorFlow Datasets](https://www.tensorflow.org/datasets) easily and you can also get an iterator from your own text file using the standard `open('my_file.txt')`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "height": 54
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 6028,
          "status": "ok",
          "timestamp": 1596062926987,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "pKITF1jR0_Of",
        "outputId": "87971889-ca85-4486-e046-e23b8ddfa7d8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(b\"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.\", 0)\n"
          ]
        }
      ],
      "source": [
        "train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()\n",
        "eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()\n",
        "print(next(train_stream))  # See one example."
      ]
    },
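    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Here is a minimal sketch of the text-file option, using only plain Python rather than any Trax API. The helper `text_file_stream`, the file name `my_file.txt` and the file format (one example per line, with the text and an integer label separated by a tab) are assumptions made for illustration; adapt the parsing to your own data. The generator loops over the file forever, so `next` never runs out of examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "\n",
        "def text_file_stream(path):\n",
        "  \"\"\"Yields (text, label) tuples from a tab-separated file, looping forever.\"\"\"\n",
        "  while True:\n",
        "    with open(path) as f:\n",
        "      for line in f:\n",
        "        text, label = line.rstrip('\\n').split('\\t')\n",
        "        yield text, int(label)\n",
        "\n",
        "\n",
        "# Write a tiny made-up dataset so the example is self-contained.\n",
        "my_file = os.path.join(tempfile.mkdtemp(), 'my_file.txt')\n",
        "with open(my_file, 'w') as f:\n",
        "  f.write('A wonderful film.\\t1\\n')\n",
        "  f.write('A dull, lifeless film.\\t0\\n')\n",
        "\n",
        "my_stream = text_file_stream(my_file)\n",
        "print(next(my_stream))  # One (text, label) example."
      ]
    },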
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fRGj4Skm1kL4"
      },
      "source": [
        "Using the `trax.data` module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines using `trax.data.Serial` and they are functions that you apply to streams to create processed streams."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 2157,
          "status": "ok",
          "timestamp": 1596062929155,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "AV5wrgjZ10yU",
        "outputId": "cae9bf73-9595-409c-cc31-ebb6d68de849"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "shapes = [(8, 1024), (8,), (8,)]\n"
          ]
        }
      ],
      "source": [
        "data_pipeline = trax.data.Serial(\n",
        "    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),\n",
        "    trax.data.Shuffle(),\n",
        "    trax.data.FilterByLength(max_length=2048, length_keys=[0]),\n",
        "    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],\n",
        "                             batch_sizes=[512, 128,  32,    8, 1],\n",
        "                             length_keys=[0]),\n",
        "    trax.data.AddLossWeights()\n",
        "  )\n",
        "train_batches_stream = data_pipeline(train_stream)\n",
        "eval_batches_stream = data_pipeline(eval_stream)\n",
        "example_batch = next(train_batches_stream)\n",
        "print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes."
      ]
    },
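    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see why a pipeline is just a function on streams, here is a conceptual sketch in plain Python. It is not how `trax.data.Serial` is implemented; `serial`, `lowercase` and `batch` are hypothetical helpers invented for this illustration. Each step takes an iterator and returns a new iterator, and `serial` chains the steps in order. Note that this toy `batch` step simply drops a trailing incomplete batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "def serial(*steps):\n",
        "  \"\"\"Composes stream-to-stream functions, applying them left to right.\"\"\"\n",
        "  def pipeline(stream):\n",
        "    for step in steps:\n",
        "      stream = step(stream)\n",
        "    return stream\n",
        "  return pipeline\n",
        "\n",
        "\n",
        "def lowercase(stream):\n",
        "  \"\"\"Lowercases the text in each (text, label) example.\"\"\"\n",
        "  for text, label in stream:\n",
        "    yield text.lower(), label\n",
        "\n",
        "\n",
        "def batch(batch_size):\n",
        "  \"\"\"Returns a step that groups consecutive examples into lists.\"\"\"\n",
        "  def step(stream):\n",
        "    examples = []\n",
        "    for example in stream:\n",
        "      examples.append(example)\n",
        "      if len(examples) == batch_size:\n",
        "        yield examples\n",
        "        examples = []\n",
        "  return step\n",
        "\n",
        "\n",
        "toy_pipeline = serial(lowercase, batch(2))\n",
        "toy_stream = iter([('Great!', 1), ('Awful.', 0), ('Fine.', 1)])\n",
        "toy_batches = toy_pipeline(toy_stream)\n",
        "print(next(toy_batches))  # First batch of two lowercased examples."
      ]
    },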
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "l25krioP2twf"
      },
      "source": [
        "### Supervised training\n",
        "\n",
        "When you have the model and the data, use `trax.supervised.training` to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "height": 442
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 58556,
          "status": "ok",
          "timestamp": 1596062987725,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "d6bIKUO-3Cw8",
        "outputId": "448ad018-7cd3-47e8-ff96-277d57b48c0f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Step      1: Ran 1 train steps in 1.06 secs\n",
            "Step      1: train CrossEntropyLoss |  1.05801070\n",
            "Step      1: eval  CrossEntropyLoss |  0.78261878\n",
            "Step      1: eval          Accuracy |  0.52656250\n",
            "\n",
            "Step    500: Ran 499 train steps in 14.93 secs\n",
            "Step    500: train CrossEntropyLoss |  0.56587476\n",
            "Step    500: eval  CrossEntropyLoss |  0.56185426\n",
            "Step    500: eval          Accuracy |  0.71562500\n",
            "\n",
            "Step   1000: Ran 500 train steps in 12.55 secs\n",
            "Step   1000: train CrossEntropyLoss |  0.40789667\n",
            "Step   1000: eval  CrossEntropyLoss |  0.40148986\n",
            "Step   1000: eval          Accuracy |  0.81093750\n",
            "\n",
            "Step   1500: Ran 500 train steps in 12.60 secs\n",
            "Step   1500: train CrossEntropyLoss |  0.36189619\n",
            "Step   1500: eval  CrossEntropyLoss |  0.44177396\n",
            "Step   1500: eval          Accuracy |  0.78007812\n",
            "\n",
            "Step   2000: Ran 500 train steps in 12.63 secs\n",
            "Step   2000: train CrossEntropyLoss |  0.31061822\n",
            "Step   2000: eval  CrossEntropyLoss |  0.35949523\n",
            "Step   2000: eval          Accuracy |  0.84218750\n"
          ]
        }
      ],
      "source": [
        "from trax.supervised import training\n",
        "\n",
        "# Training task.\n",
        "train_task = training.TrainTask(\n",
        "    labeled_data=train_batches_stream,\n",
        "    loss_layer=tl.CrossEntropyLoss(),\n",
        "    optimizer=trax.optimizers.Adam(0.01),\n",
        "    n_steps_per_checkpoint=500,\n",
        ")\n",
        "\n",
        "# Evaluaton task.\n",
        "eval_task = training.EvalTask(\n",
        "    labeled_data=eval_batches_stream,\n",
        "    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],\n",
        "    n_eval_batches=20  # For less variance in eval numbers.\n",
        ")\n",
        "\n",
        "# Training loop saves checkpoints to output_dir.\n",
        "output_dir = os.path.expanduser('~/output_dir/')\n",
        "!rm -rf {output_dir}\n",
        "training_loop = training.Loop(model,\n",
        "                              train_task,\n",
        "                              eval_tasks=[eval_task],\n",
        "                              output_dir=output_dir)\n",
        "\n",
        "# Run 2000 steps (batches).\n",
        "training_loop.run(2000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-aCkIu3x686C"
      },
      "source": [
        "After training the model, run it like any layer to get results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "height": 71
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 594,
          "status": "ok",
          "timestamp": 1596062988333,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "yuPu37Lp7GST",
        "outputId": "9dc7333e-53d1-4a08-b744-446b04f7e8eb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "example input_str: This movie reeks. No money, no acting, no nothing. I caught this on on the 3am late show movie tonight and felt compelled to comment on it. This movie has nothing to recommend it. I can't believe it ever got released to US television! Nobody in this movie can act their way out of a paperbag. The lame attempts at comedy fall flat on their face, the special effects consist of a worm-like handpuppet \"monster\"... I can't even begin to tell you how rock-bottom this production is. It looks like it cost maybe $50,000 to shoot, but only because it is on 16mm, and that is probably a generous estimate! Anyway, I lost interest rapidly and had to settle for watching \"Matlock\" reruns instead of finishing it. That's how BAD this movie is!!!\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u00
3e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u0
03cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\n",
            "Model returned sentiment probabilities: [[0.9987779  0.00122208]]\n"
          ]
        }
      ],
      "source": [
        "example_input = next(eval_batches_stream)[0][0]\n",
        "example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')\n",
        "print(f'example input_str: {example_input_str}')\n",
        "sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.\n",
        "print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Trax Quick Intro",
      "provenance": [
        {
          "file_id": "trax/intro.ipynb",
          "timestamp": 1595931762204
        },
        {
          "file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV",
          "timestamp": 1578964243645
        },
        {
          "file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07",
          "timestamp": 1572044421118
        },
        {
          "file_id": "intro.ipynb",
          "timestamp": 1571858674399
        },
        {
          "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt",
          "timestamp": 1569980697572
        },
        {
          "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5",
          "timestamp": 1563927451951
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
